library(GenericML)
library(pacman)
pacman::p_load(
  here,
  tidyverse,
  haven,
  readr,
  grf,
  kableExtra,
  knitr,
  stringr,
  dplyr
)

### IMPORTANT - Rerun this script once for every outcome. To switch outcomes,
### change `outcome` below: the outcome `Y` and its 2021 baseline controls in
### `X` are both derived from it.
### 1. Data ----
set.seed(31684591)

# Outcome analysed in this run: one of "aminwage", "wrcon", "sosec",
# "empquality_ind". Its baseline `<outcome>_2021` and missingness flag
# `missing_<outcome>_2021` are added as BLP/GATES controls further below.
outcome <- "aminwage"

data_path <- paste0(
  "C:/Users/katharina.fietz/GIGA/ReUsSITE - WP1 Data - CI Data folders/",
  "Internal - data folder/4_dofiles/Paper II/ind/replication/02_data/",
  "4_individual_ano_reg_3.dta"
)
df_all <- read_dta(data_path)

# Keep the analysis columns (each listed once).
# NOTE(review): complete.cases() drops rows with a missing value in ANY of
# these columns, including outcomes/baselines not analysed in this run --
# presumably intended so all outcomes share one sample; confirm.
df <- df_all %>%
  dplyr::select(
    wrcon, aminwage, sosec, assignment, empquality_ind,
    abj, a30, man, uni, old, exp, sup, rel, id,
    strata_all_coll, educ_tert_base, ag_ano,
    empquality_ind_2021, sosec_2021, missing_sosec_2021,
    missing_empquality_ind_2021, aminwage_2021, missing_aminwage_2021,
    wrcon_2021, missing_wrcon_2021, wave
  ) %>%
  mutate(across(c(strata_all_coll, ag_ano), as.factor)) %>%
  filter(complete.cases(.))


# Extract variables for GenericML
stopifnot(outcome %in% names(df))
Y <- as.numeric(df[[outcome]])          # outcome variable
D <- as.numeric(df$assignment)          # treatment assignment
stopifnot(all(D %in% c(0, 1)))          # treatment must be coded 0/1
covariates <- c("abj", "a30", "man", "uni", "exp", "sup", "rel", "old")
Z <- df %>%
  dplyr::select(all_of(covariates)) %>%
  mutate(across(everything(), as.numeric)) %>%  # ensure all columns are numeric
  as.matrix()

### 2. Prepare the arguments for GenericML() ----

## quantile cutoffs for the GATES grouping of the estimated CATEs
quantile_cutoffs <- c(0.25, 0.5, 0.75) # 25%, 50%, and 75% quantiles -> 4 groups

## learner of the propensity score (non-penalized logistic regression here).
# Propensity scores can also directly be supplied.
learner_propensity_score <- "mlr3::lrn('glmnet', lambda = 0, alpha = 1)"

# learners of the BCA and the CATE (here: lasso, random forest, and SVM)
learners_GenericML <- c("lasso",
                        "mlr3::lrn('ranger', num.trees = 100)",
                        "mlr3::lrn('svm')")

# data used for the CLAN: all variables of Z (no additional noise variable)
Z_CLAN <- cbind(Z)

# number of sample splits (many, to rule out seed-dependence of results)
num_splits <- 200

# whether a HT transformation shall be used when estimating BLP and GATES
HT <- FALSE

# Dummy variables for each factor column.
# NOTE(review): `~ . - 1` keeps ALL levels of the first factor (and K-1 of the
# second); if the BLP/GATES regressions also contain an intercept these
# dummies are collinear with it -- confirm this is intended.
dummy_vars <- df %>%
  dplyr::select(strata_all_coll, ag_ano) %>%
  model.matrix(~ . - 1, data = .) %>%
  as.data.frame()

# Baseline controls for the current outcome plus survey wave, combined with
# the factor dummies (the original factor columns are not included)
baseline_controls <- c(paste0(outcome, "_2021"),
                       paste0("missing_", outcome, "_2021"),
                       "wave")
X <- df %>%
  dplyr::select(all_of(baseline_controls)) %>%
  bind_cols(dummy_vars)

# Convert to matrix
X1_controls_mat <- as.matrix(X)

## A list controlling the variables that shall be used in the
# matrix X1 for the BLP and GATES regressions: here the outcome baseline
# controls, `wave` and the factor dummies stored in X1_controls_mat.
X1_BLP   <- setup_X1(covariates = X1_controls_mat)
X1_GATES <- setup_X1(covariates = X1_controls_mat)

# consider differences between group K (most affected) and group 1
diff_GATES <- setup_diff(subtract_from = "most",
                         subtracted = 1)
diff_CLAN  <- setup_diff(subtract_from = "most",
                         subtracted = 1)

# significance level used by GenericML for inference
significance_level <- 0.05

## minimum variation of predictions before
# Gaussian noise with variance var(Y)/20 is added.
min_variation <- 1e-05

## estimator of the error covariance matrix used
# in BLP and GATES (standard OLS covariance matrix estimator here)
vcov_BLP   <- setup_vcov()
vcov_GATES <- setup_vcov()

# ensure that GATES parameters are monotonically increasing (required by theory).
# NOTE(review): `monotonize` is not passed to GenericML() in section 3, so
# changing it here has no effect -- GenericML's own default is used instead.
monotonize <- TRUE

# the package allows for stratified sampling (not used here: arguments left empty)
stratify <- setup_stratify()

# no external weights are used in the analysis.
# NOTE(review): like `monotonize`, this is not passed to GenericML() in section 3.
external_weights <- NULL

# proportion of samples selected into the auxiliary set
prop_aux <- 0.5

# whether the splits and auxiliary results of the learners shall be stored
store_splits   <- TRUE
store_learners <- FALSE # to save memory

# parallelization options
parallel  <- TRUE
num_cores <- 4      # number of worker processes
seed      <- 123456
## NB: Note that the number of cores as well as your type of operating system
# (Unix vs. Windows) influences the random number stream.
# Thus, different choices of `num_cores` may lead to different results.
# Results of parallel processes are reproducible across all Unix systems,
# but might deviate on Windows systems.


### 3. Run the GenericML() function with these arguments ----
## Fits all learners on every sample split and returns a GenericML object.
# Runtime depends on the data size, `num_splits`, the learners and `num_cores`.
# NOTE(review): `monotonize` and `external_weights` from section 2 are not
# passed below, so GenericML's defaults apply for both.
genML <- GenericML(
  # data
  Z = Z, D = D, Y = Y,
  Z_CLAN = Z_CLAN,
  # learners
  learners_GenericML = learners_GenericML,
  learner_propensity_score = learner_propensity_score,
  # sample splitting
  num_splits = num_splits,
  prop_aux = prop_aux,
  stratify = stratify,
  # BLP / GATES / CLAN specification
  quantile_cutoffs = quantile_cutoffs,
  HT = HT,
  X1_BLP = X1_BLP,
  X1_GATES = X1_GATES,
  vcov_BLP = vcov_BLP,
  vcov_GATES = vcov_GATES,
  diff_GATES = diff_GATES,
  diff_CLAN = diff_CLAN,
  # inference and numerical safeguards
  significance_level = significance_level,
  min_variation = min_variation,
  # computation and storage
  parallel = parallel,
  num_cores = num_cores,
  seed = seed,
  store_splits = store_splits,
  store_learners = store_learners
)


### 4. General results ----

# Fitted GenericML object (auto-printed at top level)
genML

# Medians of the estimated \Lambda and \bar{\Lambda}, used to pick the best learners
get_best(genML)

## Key results: BLP and GATES
cat("\nBLP Results (Best Linear Predictor):\n")
blp_results <- get_BLP(genML)
print(blp_results)

cat("\nGATES Results (Group Average Treatment Effects):\n")
gates_results <- get_GATES(genML)
print(gates_results)

# Same accessors again, this time with the plot methods
results_BLP <- get_BLP(genML)
results_BLP
plot(results_BLP)

results_GATES <- get_GATES(genML)
results_GATES
plot(results_GATES)

## CLAN for each covariate, printed in turn
for (clan_var in c("abj", "uni", "a30", "sup", "rel", "exp", "man", "old")) {
  results_CLAN_z1 <- get_CLAN(genML, variable = clan_var)
  print(results_CLAN_z1)
}



# GENERATE CLAN RESULTS TABLE FOR MULTIPLE VARIABLES
library(dplyr)
library(tibble)
library(kableExtra)

# Descriptive label for each CLAN variable (customize these based on your
# study). The names double as the list of variables to analyse, so the two
# cannot drift apart.
variable_labels <- c(
  abj = "Abidjan",
  uni = "Tertiary education",
  a30 = "Age 30 or above",
  sup = "Supervisor position",
  rel = "Relationship with employer",
  exp = "Experience",
  man = "Male",
  old = "Incumbent employee (already in firm)"
)

# All variables to analyse, in table order
variables_list <- names(variable_labels)

### FUNCTION TO EXTRACT AND FORMAT CLAN RESULTS ----
# Build a one-row tibble summarising the CLAN estimates of one variable.
#
# genML_obj:      a fitted GenericML object.
# variable_name:  name of a CLAN variable, passed to get_CLAN().
# variable_label: human-readable label shown in the `Variable` column.
# least_idx, most_idx, diff_idx: positions in the get_CLAN() output of the
#   least-affected group (delta_1), the most-affected group (delta_4) and the
#   difference delta_4 - delta_1. The defaults correspond to the 4 groups from
#   quantile_cutoffs = c(0.25, 0.5, 0.75) plus the single difference requested
#   via setup_diff(subtract_from = "most", subtracted = 1) in this script.
#
# Returns a tibble with estimates formatted as "%.3f", confidence intervals as
# "(lower, upper)" and the p-value of the difference in brackets.
extract_clan_results <- function(genML_obj, variable_name, variable_label,
                                 least_idx = 1L, most_idx = 4L, diff_idx = 5L) {

  # Get CLAN results
  clan_result <- get_CLAN(genML_obj, variable = variable_name)

  # Fail loudly instead of emitting "NA" cells if the group layout changed
  n_est <- length(clan_result$estimate)
  if (max(least_idx, most_idx, diff_idx) > n_est) {
    stop("CLAN output for '", variable_name, "' has only ", n_est,
         " estimates; adjust least_idx/most_idx/diff_idx.", call. = FALSE)
  }

  est <- clan_result$estimate
  ci <- clan_result$confidence_interval
  fmt_ci <- function(i) sprintf("(%.3f, %.3f)", ci[i, 1], ci[i, 2])

  # Format results; unname() keeps the label vector's name out of the column
  tibble(
    Variable = unname(variable_label),
    `25% Least (δ1)` = sprintf("%.3f", est[least_idx]),
    `CI_Least` = fmt_ci(least_idx),
    `25% Most (δ4)` = sprintf("%.3f", est[most_idx]),
    `CI_Most` = fmt_ci(most_idx),
    `Difference (δ4 - δ1)` = sprintf("%.3f", est[diff_idx]),
    `CI_Diff` = fmt_ci(diff_idx),
    `P_value` = sprintf("[%.3f]", clan_result$p_value[diff_idx])
  )
}

### EXTRACT RESULTS FOR ALL VARIABLES ----
cat("Extracting CLAN results for all variables...\n")

# One summary row per variable; the raw get_CLAN() output is echoed too so
# the individual results appear in the log.
all_clan_results <- lapply(variables_list, function(v) {
  cat("Processing variable:", v, "\n")
  clan_row <- extract_clan_results(genML, v, variable_labels[v])

  cat("--- CLAN Results for:", v, "---\n")
  print(get_CLAN(genML, variable = v))
  cat("\n")

  clan_row
})

# Stack the per-variable rows into one table
complete_clan_table <- bind_rows(all_clan_results)



### CREATE FORMATTED TABLE ----

# Method 1: Simple table with separate columns
cat("=== SIMPLE FORMAT ===\n")
print(complete_clan_table)

# Method 2: Combined format like the reference table. Each cell stacks a
# group's point estimate over that SAME group's confidence interval; the
# difference column additionally carries the p-value in brackets.
formatted_clan_table <- complete_clan_table %>%
  mutate(
    `25% Most (δ4)` = paste(`25% Most (δ4)`, CI_Most, sep = "\n"),
    `25% Least (δ1)` = paste(`25% Least (δ1)`, CI_Least, sep = "\n"),
    `Difference (δ4 - δ1)` = paste(`Difference (δ4 - δ1)`, CI_Diff, P_value, sep = "\n")
  ) %>%
  dplyr::select(Variable, `25% Most (δ4)`, `25% Least (δ1)`, `Difference (δ4 - δ1)`)

cat("\n=== FORMATTED TABLE ===\n")
print(formatted_clan_table)

### CREATE LATEX TABLE FOR PUBLICATION ----

# GenericML reports VEIN confidence intervals at level
# 1 - 2 * significance_level (90% for significance_level = 0.05), so the
# footnote derives the level instead of hard-coding it.
ci_level_pct <- 100 * (1 - 2 * significance_level)

# With confidence intervals.
# NOTE(review): get_CLAN() reports results for the learner GenericML selected
# as best; confirm (e.g. via get_best(genML)) that this is the random forest
# before publishing the "Random Forest" header below.
latex_table_with_ci <- formatted_clan_table %>%
  kable(format = "latex",
        booktabs = TRUE,
        row.names = FALSE,
        escape = FALSE,
        align = "lccc",
        col.names = c("Variable", "25\\% Most\n(\\(\\delta_4\\))",
                      "25\\% Least\n(\\(\\delta_1\\))",
                      "Difference\n(\\(\\delta_4 - \\delta_1\\))"),
        caption = "CLAN Results: Heterogeneous Treatment Effects by Variable") %>%
  kable_styling(position = "center", latex_options = c("hold_position")) %>%
  add_header_above(c(" " = 1, "Random Forest" = 3)) %>%
  footnote(general = sprintf("Point estimates with %g\\%% confidence intervals and p-values in brackets.",
                             ci_level_pct),
           escape = FALSE)


cat("\n\n=== LATEX CODE (With CI) ===\n")
cat(latex_table_with_ci)

